feat: per-task validation metrics in GRPO/Distillation, optional max_val_samples #2498
Closed
bzantium wants to merge 2 commits into
Two related changes to the validation truncation logic.
1. Make max_val_samples optional. When the field is absent or set to None
in the recipe, validate() now iterates the entire val_dataloader.
* GRPO already typed it as `int | None # None for NeMo-Gym
compatibility` but the main validation path crashed when reading
None. Patch the read site so the main path matches the type.
* Distillation widens the TypedDict from `int` to `NotRequired[int]`
and applies the same read-site change.
The exemplar YAMLs (examples/configs/grpo_math_1B.yaml and
examples/configs/distillation_math.yaml) keep their explicit values
so the recommended default is still documented.
2. Unify Distillation truncation with GRPO. GRPO uses floor division
(max_val_samples // val_batch_size); Distillation used ceiling
division ((max_val_samples + val_batch_size - 1) // val_batch_size).
With the new None-handling branch already in place, switch
Distillation to floor division so the two algorithms behave
identically when the field is set.
Behaviour impact for existing recipes: only Distillation runs whose
max_val_samples is not divisible by val_batch_size change. The trailing
partial batch is no longer evaluated, so those runs see one batch
fewer. Recipes in examples/configs/recipes/llm
all use values that divide cleanly (256/8, 512/8 etc.), so no recipe
under examples/ is affected. Recipes that previously set an integer
that divides cleanly remain identical; recipes that previously omitted
the field could not run at all and now do.
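The None handling and the floor-versus-ceiling difference described above can be illustrated with a minimal sketch. This is not the code from grpo.py or distillation.py: the helper names max_val_batches, ceil_val_batches and iter_val_batches are hypothetical, and the config is assumed to be a plain mapping.

```python
from typing import Any, Iterable, Iterator, Mapping, Optional


def max_val_batches(config: Mapping[str, Any], val_batch_size: int) -> Optional[int]:
    """Return the batch cap, or None when max_val_samples is absent or None."""
    max_val_samples = config.get("max_val_samples")
    if max_val_samples is None:
        return None  # no cap: iterate the entire dataloader
    return max_val_samples // val_batch_size  # floor division, as in GRPO


def ceil_val_batches(max_val_samples: int, val_batch_size: int) -> int:
    """Ceiling division, the behaviour Distillation is described as having had."""
    return (max_val_samples + val_batch_size - 1) // val_batch_size


def iter_val_batches(dataloader: Iterable[Any], cap: Optional[int]) -> Iterator[Any]:
    """Yield batches until the cap is reached; a cap of None means no truncation."""
    for batch_idx, batch in enumerate(dataloader):
        if cap is not None and batch_idx >= cap:
            break
        yield batch
```

With max_val_samples=20 and val_batch_size=8, floor division stops after 2 batches where ceiling division ran 3. A cleanly dividing pair such as 256/8 gives 32 batches under both.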
Tests:
* tests/unit/algorithms/test_grpo.py adds
test_validate_iterates_full_dataloader_when_max_val_samples_is_none
* tests/unit/algorithms/test_distillation.py adds the same plus
test_validate_floor_divides_max_val_samples_by_val_batch_size to
guard the GRPO/Distillation parity.
Signed-off-by: Minho Ryu <ryumin93@gmail.com>
Multi-validation (data.validation as a list of datasets) currently runs correctly but the validation aggregator collapses everything into a single sample-weighted accuracy. Per-task progress (e.g. gsm8k vs math500) is silently lost.
task_name is already on every sample (DatumSpec.task_name preserved through rl_collate_fn into val_batch["task_name"]); validate() simply did not read it.
This commit teaches both validate() functions to track rewards per task during the loop, then emit accuracy_<task> and num_samples_<task> keys alongside the existing aggregated accuracy. logger.log_metrics plots each as its own metric automatically.
The aggregated accuracy key is preserved unchanged for dashboard backwards compatibility. Datasets without task_name are skipped, so single-task and legacy recipes behave identically.
DPO already does per-dataset metrics via its dict-of-dataloaders architecture (see dpo.validate at nemo_rl/algorithms/dpo.py:332-377), so it is not touched here.
Tests:
* test_grpo.py adds test_validate_emits_per_task_accuracy_keys.
* test_distillation.py adds the same plus a check that the aggregated accuracy key matches the sample-weighted mean across tasks.
Signed-off-by: Minho Ryu <ryumin93@gmail.com>
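The per-task bookkeeping described in this commit can be sketched roughly as follows. This is an illustration under assumptions, not the patch itself: the function name summarize_validation and its flat-list inputs are hypothetical. Only the accuracy, accuracy_<task> and num_samples_<task> key names and the per_task_rewards name come from the PR text.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence


def summarize_validation(
    rewards: Sequence[float],
    task_names: Optional[Sequence[Optional[str]]] = None,
) -> Dict[str, float]:
    """Aggregate rewards overall and, when task names exist, per task."""
    metrics: Dict[str, float] = {
        # Aggregated key, computed the same way with or without task names.
        "accuracy": sum(rewards) / len(rewards) if rewards else 0.0,
    }
    if task_names is None:
        return metrics  # legacy dataset: per-task block skipped

    per_task_rewards: Dict[str, List[float]] = defaultdict(list)
    for reward, task in zip(rewards, task_names):
        if task is not None:
            per_task_rewards[task].append(reward)

    for task, task_rewards in per_task_rewards.items():
        metrics[f"accuracy_{task}"] = sum(task_rewards) / len(task_rewards)
        metrics[f"num_samples_{task}"] = len(task_rewards)
    return metrics
```

Because every sample contributes once to the aggregate and once to its task, the aggregated accuracy equals the sample-weighted mean of the per-task accuracies, which is the property the Distillation test is described as checking.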
Author
|
Closed because the head branch on the fork was renamed ( |
What does this PR do ?
Closes #2497.
Two related improvements to the `validate()` function in GRPO and Distillation, bundled into one PR because they touch the same function in the same way.

1. Per-task validation metrics (main change)

When `data.validation` is configured as a list of multiple datasets, the multi-validation path correctly loads them all and dispatches per-task to the right environment during rollout, but the validation aggregator collapses everything into a single sample-weighted `accuracy` and `avg_length`. Per-task progress (e.g. gsm8k vs math500) is silently lost.

`task_name` is already on every sample (`DatumSpec.task_name`, preserved through `rl_collate_fn` into `val_batch["task_name"]`); `validate()` simply did not read it. This PR teaches both `validate()` functions to track rewards per task during the loop and emit:

* `accuracy_<task>` for each task seen
* `num_samples_<task>` for each task seen

The aggregated `accuracy` key is preserved unchanged so existing dashboards continue to work. Single-task runs and legacy datasets without `task_name` are unaffected (the per-task block is skipped).

The driver-log summary also gains a per-task block:
2. `max_val_samples` becomes optional, Distillation truncation matches GRPO

Bundled because the patches sit a few lines apart in the same `validate()` block.

* `grpo.GRPOConfig.max_val_samples` was already typed as `int | None` (`grpo.py:150`) for NeMo-Gym compatibility but the main path crashed on `None`. Distillation's TypedDict required `int` outright. Both now accept `None`/absent and fall back to the full `val_dataloader`.
* Distillation truncation switches from ceiling division to floor division to match GRPO. This only changes runs whose `max_val_samples` is not divisible by `val_batch_size`; all shipped recipes under `examples/configs/recipes/llm/` use values that divide cleanly so none are affected.

Files touched
* `nemo_rl/algorithms/grpo.py`: `validate()`. Optional `max_val_samples` branch. Per-task block in summary print.
* `nemo_rl/algorithms/distillation.py`: `NotRequired[int]`. Truncation switched to floor division.
* `tests/unit/algorithms/test_grpo.py`: `test_validate_emits_per_task_accuracy_keys`, `test_validate_iterates_full_dataloader_when_max_val_samples_is_none`.
* `tests/unit/algorithms/test_distillation.py`: `test_validate_floor_divides_max_val_samples_by_val_batch_size` to guard the GRPO/Distillation parity.

Out of scope
* DPO already does per-dataset metrics via its `dict[str, StatefulDataLoader]` architecture (see `nemo_rl/algorithms/dpo.py:332-377`, prefix `validation-{dataset_name}`). No change needed.
* [...] `val_batch_size` to `len(val_dataset)` at setup time, which is specific to its single-batch eval and not something the main paths should adopt.
* The exemplar YAMLs (`examples/configs/grpo_math_1B.yaml`, `examples/configs/distillation_math.yaml`) keep their explicit `max_val_samples` so the recommended default stays documented.

Backwards compatibility
* `validation/accuracy` key unchanged in value.
* If `per_task_rewards` ends up with one key, you get one extra `validation/accuracy_<the-only-task>` metric. Harmless.
* Datasets without `task_name`: per-task block is skipped, behaviour matches the old code.

Issues
Closes #2497.
Usage
A recipe with multiple validation datasets now reports per-task metrics:
wandb will plot `validation/accuracy`, `validation/accuracy_gsm8k`, and `validation/accuracy_ResponseDataset` separately. The existing aggregated `validation/accuracy` panel keeps working unchanged.

Before your PR is "Ready for review"
Pre checks:
Additional Information
* `config-conventions` skill: optional `max_val_samples` field expressed via `NotRequired`, no hidden defaults, exemplar YAMLs keep their explicit recommended values.
* Should task names be slugified in metric keys (e.g. `data-math500` → `data_math500`)? Hyphens work in wandb but render slightly oddly in some downstream stores. Current patch keeps the name verbatim; happy to slugify if preferred.
* Based on `origin/main` in a separate worktree to keep the change isolated from fix: pass trust_remote_code=True to remaining AutoConfig.from_pretrained sites #2496 (`fix/autoconfig-trust-remote-code`); the two PRs touch disjoint files.
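If slugifying is preferred, one possible shape for it is sketched below. The helper `slugify_task_name` is hypothetical and not part of this patch, which keeps task names verbatim. The choice to collapse every run of characters outside `[0-9A-Za-z_]` into a single underscore is an assumption about what downstream stores tolerate.

```python
import re


def slugify_task_name(task_name: str) -> str:
    """Replace each run of characters outside [0-9A-Za-z_] with one underscore."""
    return re.sub(r"[^0-9A-Za-z_]+", "_", task_name)


def metric_key(task_name: str) -> str:
    """Build the per-task accuracy key from a slugified task name."""
    return f"accuracy_{slugify_task_name(task_name)}"
```

Names that are already safe, such as `gsm8k` or `ResponseDataset`, pass through unchanged, so existing per-task panels would keep their keys.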